{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2b0453b9-e9da-4e1c-855c-cab94c9064ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.651 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import jieba\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# 读取数据集，这里是直接联网读取，也可以通过下载文件，再读取\n",
    "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n",
    "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n",
    "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)\n",
    "\n",
    "train_text = '。'.join(list(train_data[0]))\n",
    "train_words = jieba.lcut(train_text)\n",
    "\n",
    "cn_stopwords = ' '.join(pd.read_csv('https://mirror.coggle.club/stopwords/baidu_stopwords.txt', header=None)[0])\n",
    "train_words = [x for x in train_words if x not in cn_stopwords]\n",
    "train_words = [x for x in train_words if len(x) > 1]\n",
    "train_words = [x for x in train_words if not x.isdigit()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6557e4fe-5e10-4959-8ec7-d3fd155635e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "train_words_freq = Counter(train_words)\n",
    "train_words = [x for x in train_words if train_words_freq[x] >= 5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1e78e280-ff02-4df4-9d20-092d2a8fa642",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_word_prior = {}\n",
    "for row in train_data.iloc[:].itertuples():\n",
    "    text, label = row[1], row[2]\n",
    "    words = jieba.lcut(text)\n",
    "    words = [x for x in words if x in train_words]\n",
    "\n",
    "    if len(words) == 0:\n",
    "        continue\n",
    "\n",
    "    for word in words:\n",
    "        if word not in train_word_prior:\n",
    "            train_word_prior[word] = {\"total\": 0}\n",
    "\n",
    "        if label not in  train_word_prior[word]:\n",
    "            train_word_prior[word][label] = 0\n",
    "\n",
    "        train_word_prior[word][label] += 1\n",
    "        train_word_prior[word][\"total\"] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0d317a33-42da-456c-8015-e0b57d35138d",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_word_prior = pd.DataFrame(train_word_prior).T\n",
    "train_word_prior.fillna(0, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f18d195b-8ac9-4b60-9fa2-14ce00777e2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for category in train_data[1].unique():\n",
    "    train_word_prior[category] /= train_word_prior['total']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a511da43-9d5a-4b53-a538-339cdee6cee4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_59295/1433704415.py:2: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
      "  train_word_prior.groupby('category').apply(lambda x: list(x.index))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "category\n",
       "Alarm-Update             [早上, 我定, 下午, 参加, 公司, 闹钟, 活动, 提醒, 创建, 周末, 上午, 取...\n",
       "Audio-Play               [故事, 小说, 广播剧, 英文版, 岳云鹏, 爆笑, 相声, 有声, 俄语, 第五章, 郭...\n",
       "Calendar-Query           [昨天, 农历, 我查, 星期, 几号, 告诉, 几月, 查查, 礼拜, 几是, 春节, 母...\n",
       "FilmTele-Play            [播放, 古装, 爱情, 电视剧, 一个, 推理, 一会, 地方, 导演, 赵丽颖, 麻烦,...\n",
       "HomeAppliance-Control    [空调, 客厅, 风速, 打开, 烤箱, 儿童房, 调高, 洗衣机, 停止, 工作, 模式,...\n",
       "Music-Play               [随便, 一首, 专辑, 单曲, 循环, 王菲, 钢琴曲, 随机, 治愈, 日语, 歌曲, ...\n",
       "Other                                     [永远, 电话, 笑话, 之间, 老婆, 不好, 漫画, 有人]\n",
       "Radio-Listen             [河南, 新闻广播, 新闻台, 交通, 广播电台, 经典音乐, 七点, 中央, 电台, 都市...\n",
       "TVProgram-Play           [播出, 卫视, 广西, 法治, CCTV11, 剧场, 开播, 文化, 结束, 早间, 贵...\n",
       "Travel-Query             [汽车票, 回家, 深圳, 武汉, 北京, 桂林, 飞机, 起飞, 快点, 三张, 成都, ...\n",
       "Video-Play               [挑战, 游戏, 视频, 和平, 精英, 花絮, 转播, 比赛, 现场, 世界, 年谍, 第...\n",
       "Weather-Query            [查询, 海南, 几级, 刮风, 几天, 山西, 明天, 衡水, 气温, 适合, 杭州, 香...\n",
       "dtype: object"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_word_prior['category'] = train_word_prior.columns[1:][train_word_prior.values[:, 1:].argmax(1)]\n",
    "train_word_prior.groupby('category').apply(lambda x: list(x.index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1787ac-fdc5-461f-92bc-d6361cab525d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
